import json
import re
from dataclasses import dataclass
from datetime import timedelta
from pathlib import Path

import pandas as pd

from rdagent.app.data_science.conf import DS_RD_SETTING
from rdagent.components.coder.CoSTEER.evaluators import (
    CoSTEEREvaluator,
    CoSTEERSingleFeedback,
)
from rdagent.components.coder.data_science.conf import get_clear_ws_cmd, get_ds_env
from rdagent.components.coder.data_science.utils import remove_eda_part
from rdagent.core.evolving_framework import QueriedKnowledge
from rdagent.core.experiment import FBWorkspace, Task
from rdagent.log import rdagent_logger as logger
from rdagent.log.timer import RD_Agent_TIMER_wrapper
from rdagent.scenarios.data_science.dev.runner import DSRunnerCoSTEERSettings
from rdagent.scenarios.data_science.test_eval import (
    MLETestEval,
    NoTestEvalError,
    get_test_eval,
)
from rdagent.utils.agent.tpl import T
from rdagent.utils.agent.workflow import build_cls_from_json_with_retry
from rdagent.utils.fmt import shrink_text

# Absolute, symlink-resolved directory that contains this module.
DIRNAME = Path(__file__).resolve().parent


@dataclass
class DSRunnerFeedback(CoSTEERSingleFeedback):
    """
    Feedback for Data Science CoSTEER evaluation.
    This feedback is used to evaluate the code and execution of the Data Science CoSTEER task.

    On top of the fields inherited from `CoSTEERSingleFeedback`, it carries an explicit
    acceptance flag, the hyperparameter tuning verdict/suggestion and the validation score.
    """

    # Explicit acceptance verdict; when None, `is_acceptable` defers to the parent class.
    acceptable: bool | None = None
    # Truthy when hyperparameter tuning is suggested; gates the suggestion section in `__str__`.
    hyperparameter_tuning_decision: bool | None = None
    # Free-text tuning suggestion, only rendered when the decision above is truthy.
    hyperparameter_tuning_suggestion: str | None = None
    # Validation score shown in the report; None means that no score is available.
    score: str | None = None

    def is_acceptable(self) -> bool:
        """Return the explicit `acceptable` flag, falling back to the parent implementation if unset."""
        if self.acceptable is not None:
            return self.acceptable
        return super().is_acceptable()

    def __str__(self) -> str:
        """Render the feedback as a markdown-like report with one section per aspect."""
        parts = [
            "### Execution",
            str(self.execution),
            "### Return Check",
            self.return_checking if self.return_checking is not None else "No return checking",
            "### Code",
            str(self.code),
            "### Validation Score",
            # Compare against None explicitly: a score of 0 is a real score and must be displayed.
            f"{self.score}" if self.score is not None else "Not available",
            "### Final Decision",
            f"This implementation is {'PASSED' if self.acceptable else 'FAILED'}.",
        ]
        if self.hyperparameter_tuning_decision:
            parts.append("### Hyperparameter Tuning Suggestion")
            parts.append(str(self.hyperparameter_tuning_suggestion))
        return "\n".join(parts)


# Old name for `DSRunnerFeedback`, kept as an alias so that code using it keeps working.
DSCoSTEEREvalFeedback = DSRunnerFeedback  # FIXME: Alias for backward compatibility


class DSRunnerEvaluator(CoSTEEREvaluator):
    """Run a data science workspace end to end and turn the outcome into a `DSRunnerFeedback`."""

    def evaluate(
        self,
        target_task: Task,
        implementation: FBWorkspace,
        gt_implementation: FBWorkspace,
        queried_knowledge: QueriedKnowledge | None = None,
        **kwargs,
    ) -> DSRunnerFeedback:
        """Execute `main.py` in the workspace and build the feedback for this run.

        The method cleans the workspace, runs the workflow under coverage, validates
        `scores.csv` and (when enabled) the submission file, asks the LLM for a verdict
        and finally overrides that verdict with the results of the hard checks.

        Args:
            target_task: Task being evaluated; its task information is used for the prompt
                and as the key into the former failed traces.
            implementation: Workspace that is executed and inspected (files are injected into it).
            gt_implementation: Not used by this evaluator.
            queried_knowledge: Knowledge holding the former failed traces; may be None, in
                which case no former failures are assumed.
            **kwargs: Not used by this evaluator.

        Returns:
            The feedback built from the LLM answer, with `score`, `acceptable`,
            `final_decision`, `code` and `return_checking` adjusted by the local checks.
        """
        env = get_ds_env(
            extra_volumes={
                f"{DS_RD_SETTING.local_data_path}/{self.scen.competition}": T(
                    "scenarios.data_science.share:scen.input_path"
                ).r()
            },
            running_timeout_period=self.scen.real_full_timeout(),
        )

        # Remove previous submission and scores files generated by workflow.
        implementation.execute(env=env, entry=get_clear_ws_cmd())

        # get previous runner loops
        task_info = target_task.get_task_information()
        if queried_knowledge is not None:
            queried_former_failed_knowledge = queried_knowledge.task_to_former_failed_traces[task_info][0]
        else:
            # Without queried knowledge there are no former failures to report.
            queried_former_failed_knowledge = []

        # execute workflow
        result = implementation.run(env=env, entry="python -m coverage run main.py")
        stdout = result.get_truncated_stdout()
        execute_ret_code = result.exit_code
        implementation.running_info.running_time = result.running_time

        match = re.search(r"(.*?)=== Start of EDA part ===(.*)=== End of EDA part ===", stdout, re.DOTALL)
        # Remember whether a real EDA section was found before the placeholder text is substituted.
        eda_found = match is not None and bool(match.group(2))
        eda_output = match.group(2) if match is not None else "No EDA output."
        implementation.inject_files(
            **{
                "EDA.md": eda_output,
                "stdout.txt": result.stdout if DSRunnerCoSTEERSettings().dump_stdout_type == "full" else stdout,
            }
        )  # stdout.txt is used for debugging. not used in any other place.
        stdout = remove_eda_part(stdout)
        stdout += f"The code executed {'successfully' if execute_ret_code == 0 else 'failed'}. {'The EDA output is removed from the stdout. ' if eda_found else ''}"

        # Check score file
        score_ret_code, score_check_text, score_df = self._check_score_file(implementation)

        # DockerEnv for MLEBench submission validation
        submission_check_out = ""
        submission_ret_code = 0
        test_eval = get_test_eval()

        if test_eval.enabled(self.scen.competition):
            submission_check_out, submission_ret_code = test_eval.valid(self.scen.competition, implementation)
            stdout += f"\n### Submission check:\n{submission_check_out}\nIf Submission check returns a 'Submission is valid' or similar message, despite some warning messages, you should still consider the submission as valid and give a positive final decision. "

        # Share of the runner timeout consumed by this run.
        time_spent_ratio = implementation.running_info.running_time / env.conf.running_timeout_period
        enable_hyperparameter_tuning_check = self._hyperparameter_tuning_check_enabled(
            queried_former_failed_knowledge, time_spent_ratio
        )

        system_prompt = T(".prompts:DSCoSTEER_eval.system").r(
            scenario=self.scen.get_scenario_all_desc(eda_output=implementation.file_dict.get("EDA.md", None)),
            task_desc=task_info,
            enable_hyperparameter_tuning_check=enable_hyperparameter_tuning_check,
        )
        user_prompt = T(".prompts:DSCoSTEER_eval.user").r(
            code=implementation.all_codes,
            change_summary=implementation.change_summary,
            stdout=shrink_text(stdout),
            time_spent=f"{implementation.running_info.running_time:.2f} seconds",
            timeout=f"{env.conf.running_timeout_period} seconds",
            percent_of_timeout_used=f"{time_spent_ratio * 100:.2f}%",
            queried_former_failed_knowledge=queried_former_failed_knowledge,
        )

        feedback = build_cls_from_json_with_retry(
            DSRunnerFeedback,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
        )
        try:
            feedback.score = score_df.loc["ensemble"].iloc[0] if score_ret_code == 0 else None
        except Exception as e:  # best effort: a missing score must not abort the evaluation
            logger.error(f"Failed to get the score from scores.csv: {e}")
            feedback.score = None
        feedback.final_decision = feedback.acceptable and (
            not feedback.hyperparameter_tuning_decision
        )  # If hyperparameter_tuning_decision is None, it's considered as False, so the final_decision dependents on the acceptable

        if feedback and not DS_RD_SETTING.coder_on_whole_pipeline:
            self._remove_unused_scripts(implementation, env, feedback)

        # Hard checks override whatever the LLM decided.
        if score_ret_code != 0:
            feedback.acceptable = feedback.final_decision = False
            # return_checking may be None, so never use `+=` on it directly.
            feedback.return_checking = (feedback.return_checking or "") + "\n" + score_check_text
        if submission_ret_code != 0:
            feedback.acceptable = feedback.final_decision = False
            feedback.return_checking = (feedback.return_checking or "") + "\nSubmission file check failed."
        return feedback

    def _check_score_file(self, implementation: FBWorkspace) -> tuple[int, str, pd.DataFrame | None]:
        """Validate `scores.csv` in the workspace.

        Returns:
            A tuple `(ret_code, message, score_df)`: `ret_code` is 0 when every check passed
            and 1 otherwise, `message` collects the error descriptions, and `score_df` is the
            parsed file (None when it is missing or could not be read).
        """
        score_fp = implementation.workspace_path / "scores.csv"
        score_ret_code = 0
        score_check_text = ""
        score_df = None

        if not score_fp.exists():
            logger.warning("Metrics file (scores.csv) is not generated!")
            return 1, "[Error] Metrics file (scores.csv) is not generated!", None

        try:
            score_df = pd.read_csv(score_fp, index_col=0)
            model_set_in_scores = set(score_df.index)
            model_set_in_folder = {
                f[:-3] for f in implementation.file_dict.keys() if re.match(r"^model_(?!test)\w+\.py$", f)
            }

            # Check model names (index)
            # in Pipeline task, we only check ensemble in scores.csv
            if DS_RD_SETTING.coder_on_whole_pipeline:
                if not score_df.index.is_unique:
                    score_check_text += "\n[Error] The file 'scores.csv' contains duplicate model names."
                    score_ret_code = 1
                if "ensemble" not in model_set_in_scores:
                    score_check_text += "\n[Error] The file 'scores.csv' doesn't contain the ensemble model."
                    score_ret_code = 1
                if score_ret_code != 0:
                    # Leading newline keeps this line from running into the previous error message.
                    score_check_text += f"\nThe dataframe in file 'scores.csv' is:\n{score_df}"
            else:
                expected_models = model_set_in_folder.union({"ensemble"})
                if model_set_in_scores != expected_models:
                    score_check_text += f"\n[Error] The scores dataframe does not contain the correct model names as index.\ncorrect model names are: {expected_models}\nscore_df is:\n{score_df}"
                    score_ret_code = 1

            # Check metric name (columns) - case insensitive
            if [col.lower() for col in score_df.columns.tolist()] != [self.scen.metric_name.lower()]:
                score_check_text += f"\n[Error] The scores dataframe does not contain the correct column names.\nCorrect columns is: ['{self.scen.metric_name}']\nBut got: {score_df.columns.tolist()}"
                score_ret_code = 1

        except Exception as e:
            logger.error(f"Error in checking the scores.csv file: {e}")
            # errors="replace" so that an undecodable file cannot raise again inside this handler.
            score_check_text += f"\n[Error] in checking the scores.csv file: {e}\nscores.csv's content:\n-----\n{score_fp.read_text(errors='replace')}\n-----"
            score_ret_code = 1

        return score_ret_code, score_check_text, score_df

    def _hyperparameter_tuning_check_enabled(self, former_failed_traces, time_spent_ratio: float) -> bool:
        """Tell whether the LLM should also be asked for a hyperparameter tuning verdict.

        Args:
            former_failed_traces: Former failed traces of the current task (empty on the first loop).
            time_spent_ratio: Running time of this run divided by the runner timeout.
        """
        # 1. This is the first loop of evaluation.
        if DS_RD_SETTING.only_first_loop_enable_hyperparameter_tuning:
            c1 = len(former_failed_traces) == 0
        else:
            c1 = True

        # 2. The current time spent on runner is less than the time limit ratio for runner timeout.
        c2 = time_spent_ratio < DS_RD_SETTING.time_ratio_limit_to_enable_hyperparameter_tuning

        # 3. Only enable hyperparameter tuning during the merge stage if configured.
        # TODO: it is not restricted in merge stage now for fast implementation.
        timer = RD_Agent_TIMER_wrapper.timer
        res_time = timer.remain_time()
        if DS_RD_SETTING.only_enable_tuning_in_merge:
            # An unknown remaining time does not block tuning, mirroring the fallback of condition 4.
            c3 = res_time is None or res_time <= timedelta(hours=DS_RD_SETTING.merge_hours)
        else:
            c3 = True

        # 4. The current time spent on global is less than the time limit ratio for whole timeout.
        if timer.all_duration is not None and res_time is not None:
            res_ratio = res_time / timer.all_duration
            c4 = res_ratio <= DS_RD_SETTING.res_time_ratio_limit_to_enable_hyperparameter_tuning
        else:
            c4 = True

        # Only enable hyperparameter tuning check if all conditions are met
        return c1 and c2 and c3 and c4

    def _remove_unused_scripts(self, implementation: FBWorkspace, env, feedback: DSRunnerFeedback) -> None:
        """Use the coverage report to reject runs that skip required scripts and to delete unused ones.

        `feedback` is modified in place: it is marked as failed when no model script or a
        must-have script is unused by `main.py`.
        """
        implementation.execute(env=env, entry="python -m coverage json -o coverage.json")
        coverage_report_path = implementation.workspace_path / "coverage.json"
        if not coverage_report_path.exists():
            return

        used_files = set(json.loads(coverage_report_path.read_text())["files"].keys())
        coverage_report_path.unlink()
        logger.info(f"All used scripts: {used_files}")

        use_one_model = any(f.startswith("model_") and "test" not in f for f in used_files)
        if not use_one_model:
            feedback.acceptable = feedback.final_decision = False
            logger.warning("No model script is used in `main.py`.")
            feedback.code = (feedback.code or "") + "\n[Error] No model script is used in `main.py`."

        all_python_files = set(Path(implementation.workspace_path).rglob("*.py"))
        must_have_files = ["load_data.py", "feature.py", "ensemble.py"]

        unused_files = [
            py_file.name
            for py_file in all_python_files
            if not (py_file.name in used_files or py_file.name.endswith("test.py"))
        ]
        if not unused_files:
            return

        logger.warning(f"Unused scripts: {unused_files}")
        error_files = set(unused_files).intersection(set(must_have_files))
        if error_files:
            feedback.acceptable = feedback.final_decision = False
            logger.warning(f"{error_files} must be used in `main.py`.")
            feedback.code = (feedback.code or "") + f"\n[Error] {error_files} must be used in `main.py`."
        elif use_one_model:
            logger.info("Remove unused scripts.")
            implementation.inject_files(**{file: implementation.DEL_KEY for file in unused_files})
